#!/usr/bin/env python2
#
# A simple python client to speak to the donbL AVR boot loader
# donb (donb@capitolhillconsultants.com)
# October 7th, 2011
#

from intelhex import IntelHex
import serial
import string
import array
import time
import re
import os
import sys

# defaults
#
# serial read timeout handed to serial.Serial() in blsinit()
TIMEOUT=5

# Shared module state. A `global` statement at module scope is a no-op and
# binds nothing, so the names are created explicitly here. Reading one before
# the function that fills it in has run now yields None instead of NameError.
s = None            # serial port object, opened by blsinit()
psize = None        # flash page size, filled in by pagesz()
signature = None    # device signature bytes, read by fromweb()
calibration = None  # RC calibration byte from the signature reply

# send header
#
def sendheader(t, l):
	"""Write a 3-byte command header to the serial port.

	t -- command type: an int (converted with chr()) or a value that is
	     written to the port as-is
	l -- payload length, sent as two bytes, high byte first

	Always returns 1.
	"""
	global s

	kind = chr(t) if type(t) is int else t
	s.write(kind)

	# 16-bit length, most significant byte first
	for shift in (8, 0):
		s.write(chr((l >> shift) & 0xff))

	return 1

# error names
#
def errname(x):
	"""Translate a status character into its symbolic error name.

	x -- single character holding the status byte

	Returns the name for codes 1-3; any other code is returned as the
	plain integer value of the byte.
	"""
	code = ord(x)
	names = {
		1: "BL_ERR_UNIMP",
		2: "BL_ERR_INVALID",
		3: "BL_ERR_CHECKSUM",
	}
	return names.get(code, code)

# get header
#
def getheader():
	"""Read a 3-byte response header and return the payload length.

	The header is one status byte followed by a 16-bit length, high byte
	first. A non-zero status is reported on stdout; the length is still
	returned in that case.

	Returns the payload length, or -1 if the port timed out before a full
	header arrived. Every caller in this file compares the result against
	the non-negative length it expects, so -1 is handled as an ordinary
	command failure.
	"""
	global s

	hdr = s.read(3)
	if len(hdr) != 3:
		# A timed-out read hands back fewer bytes than requested. The
		# empty result used to reach ord() and raise TypeError.
		print("getheader: timed out waiting for response header")
		return -1

	e = hdr[0]
	if ord(e) != 0:
		# "%s %s" reproduces the separator that `print a, b` inserts
		print("%s %s" % ("getheader: BL_ERR_OK not found: ", errname(e)))

	return (ord(hdr[1]) << 8) | ord(hdr[2])

# get value
#
def getvalue(x):
	"""Read x bytes from the serial port, one at a time.

	x -- number of bytes to read

	Returns an array.array('B') holding the byte values in arrival order.
	"""
	received = array.array('B')
	remaining = x
	while remaining:
		received.append(ord(s.read()))
		remaining -= 1

	return received

# address
#
def addr(x):
	"""Set the boot loader's current address (command 5).

	x -- the address: an int, or a string that is always parsed as hex

	Returns 1 on success, 0 on failure (bad address or unexpected reply).
	"""
	# Parse and validate before anything goes on the wire. Previously the
	# command header was sent first, so a malformed string raised ValueError
	# after the header had been written but before its 4 payload bytes.
	if isinstance(x, str):
		try:
			x = int(x, 16)
		except ValueError:
			print("error: invalid hex address '%s'" % x)
			return 0

	# only four address bytes are transmitted
	if x < 0 or x > 0xffffffff:
		print("error: address out of range")
		return 0

	sendheader(5, 4)

	# 32-bit address, most significant byte first
	for shift in (24, 16, 8, 0):
		s.write(chr((x >> shift) & 0xff))

	if getheader() != 4:
		print("error: address setting failed")
		return 0

	# The reply carries 4 bytes (presumably the address echoed back). They
	# are read so they don't stay queued on the port, and otherwise ignored.
	getvalue(4)

	return 1

# peek a word at an address
#
def peek(a):
	"""Read and print the 16-bit word at an address (command 6).

	a -- address in any form accepted by addr(): int or hex string

	Returns 1 on success, 0 on failure.
	"""
	# Stop if the address was not accepted; otherwise the value printed
	# below would be labelled with an address that was never set.
	if addr(a) != 1:
		print("peek failed: couldnt set address")
		return 0

	sendheader(6, 0)

	x = getheader()
	if x != 2:
		print("peek failed: message count incorrect")
		return 0

	# two payload bytes, high byte first
	x = getvalue(x)
	x = ((x[0] << 8)|x[1])
	print("peek %s: %.04x" % (a, x))

	return 1

# get the page size
#
def pagesz():
	"""Query the flash page size (command 3), print it and store it.

	On success the value is stored in the module-level `psize`.

	Returns 1 on success, 0 if the reply does not carry exactly 2 bytes.
	"""
	global psize

	sendheader(3, 0)

	count = getheader()
	if count != 2:
		print("pagesz failed: message count incorrect")
		return 0

	# 16-bit value, high byte first
	hi, lo = getvalue(count)
	size = (hi << 8) | lo
	print("pagesz: %x" % size)
	psize = size

	return 1

# get the signature bytes and calibration byte
#
def getsignature():
	"""Fetch and print the signature bytes and RC calibration byte (command 1).

	On success the first three reply bytes are stored in the module-level
	`signature` and the fourth in `calibration`.

	Returns 1 on success, 0 if the reply does not carry exactly 4 bytes.
	"""
	# Without these declarations the assignments below only bind locals,
	# and fromweb() cannot see the signature it asks this function for.
	global signature
	global calibration

	sendheader(1, 0)

	x = getheader()
	if x != 4:
		print("error: signature header size invalid")
		return 0

	x = getvalue(x)
	signature = x[0:3]
	calibration = x[3]
	print("signature: %.02x %.02x %.02x" % (x[0], x[1], x[2]))
	print("rc calibration: %.02x" % (x[3]))

	return 1

# get the fuse bytes
#
def fuse():
	"""Fetch and print the four fuse bytes (command 2).

	Returns 1 on success, 0 if the reply does not carry exactly 4 bytes.
	"""
	sendheader(2, 0)

	count = getheader()
	if count != 4:
		print("fuse failed: message count incorrect")
		return 0

	values = getvalue(count)
	print("fuse: %.02x %.02x %.02x %.02x" % tuple(values))

	return 1

# exit the BLS command loop
#
def doexit():
	"""Send the exit command (7) and check for an empty reply.

	Returns 1 if the reply header reports a zero-length payload, else 0.
	"""
	sendheader(7, 0)

	if getheader() == 0:
		return 1

	print("error: exit attempt failed")
	return 0

# initialize the SIM
#
def blsinit(p, b):
	"""Open the serial port, pulse RTS/DTR and wait for the boot loader.

	p -- serial port name
	b -- baud rate

	The opened port is stored in the module-level `s`.

	Returns 1 once the '+' greeting has been seen and '?' sent back,
	0 on timeout or an unexpected greeting byte.
	"""
	global s

	# The control line must be wired to RESET on the AVR so that toggling
	# it forces the boot loader section to execute; both RTS and DTR are
	# pulsed here.
	s = serial.Serial(p, b, timeout=TIMEOUT, parity=serial.PARITY_NONE, stopbits=serial.STOPBITS_ONE, rtscts=0)
	s.setRTS(1)
	s.setDTR(1)
	time.sleep(0.01)
	s.flushInput()
	s.setRTS(0)
	s.setDTR(0)
	time.sleep(0.01)

	# based on a (presumed) bug in the AVR architecture, the BL won't
	# accept USART input unless it first outputs a byte. Not sure why
	# this occurs, but here is the compensating code:
	x = s.read()

	# A timed-out read returns an empty string, not None. The old
	# `x == None` test never fired and the empty string reached ord().
	if not x:
		print("error: no paramedics available")
		return 0
	if x != '+':
		print("error: unexpected value: '+' != %.02x" % (ord(x)))
		return 0

	print("boot loader found. entering command mode.")
	s.write('?')

	return 1

# upload to bootloader
#
def upload(p, d):
	"""Send one block of data to the boot loader (command 4).

	p -- address the data belongs at, passed to addr()
	d -- sequence of byte values (ints) to send

	The one-byte reply is compared against the two's complement of the
	byte sum of d.

	Returns 1 on success, 0 on any failure.
	"""
	# the target address has to be in place before the payload
	if addr(p) != 1:
		print("error: couldnt set page address")
		return 0

	sendheader(4, len(d))

	# ship the payload one byte at a time, summing as we go
	total = 0
	for byte in d:
		total += byte
		s.write(chr(byte))

	# two's complement of the sum, reduced to one byte
	expected = (-total) & 0xff

	if getheader() != 1:
		print("error: flash failed")
		return 0

	got = getvalue(1)[0]
	if got != expected:
		print("error: checksums dont match: got=%.02x, expected=%.02x" % (got, expected))
		return 0

	return 1

# flash an image to the chip
#
def _upload_page(p, d):
	"""Upload one assembled page and print a progress dot.

	Returns 1 on success, 0 (after printing an error) on failure.
	"""
	if upload(p, d) != 1:
		print("error: upload failed")
		return 0
	sys.stdout.write(".")
	# stdout may be buffered; flush so the dot shows up while flashing
	sys.stdout.flush()
	return 1

def flash(f):
	"""Flash an Intel HEX image to the chip, one page at a time.

	f -- Intel HEX source, handed to IntelHex()

	Bytes are grouped by flash page. Holes inside a page, before its last
	used byte, are padded with 0xff. A page is sent only up to its last
	used byte.

	Returns 1 on success, 0 on failure.
	"""
	global psize

	# make sure we have the page size first; the masks below depend on it
	if pagesz() != 1:
		print("error: couldnt retrieve page size")
		return 0

	# ingest the ihex file
	h = IntelHex(f)

	sys.stdout.write("flashing pages: ")
	sys.stdout.flush()

	# offset-within-page mask; assumes the page size is a power of two
	mask = psize - 1

	# always start with page zero; the code will auto-jump to
	# the new address if there is no page zero in the hexfile
	p = 0
	d = []

	# relies on addresses() yielding the used addresses in ascending
	# order, as the intelhex documentation states
	for a in h.addresses():
		page = a & ~mask

		# if the page has changed, upload what was collected so far
		if page != p:
			if len(d) > 0 and _upload_page(p, d) != 1:
				return 0
			p = page
			d = []

		# Pad up to this byte's offset if addresses were skipped. The old
		# code used d.insert(n, 0xff), which placed the padding in front
		# of bytes already collected and shifted them to wrong offsets.
		t = a & mask
		d.extend([0xff] * (t - len(d)))

		d.append(h[a])

	# addresses() is exhausted; upload whatever is still pending
	if len(d) > 0 and _upload_page(p, d) != 1:
		return 0

	print("")

	return 1

# flash an image from the web
#
def fromweb():
	"""Download the firmware image matching the chip signature and flash it.

	Calls getsignature() and then reads the module-level `signature`, so it
	depends on getsignature() storing the signature there. The image is
	fetched with curl into /tmp/.goodfet.hex and passed to flash().

	Returns flash()'s result, or 0 if the signature or download failed.
	"""
	global signature

	# local import: only this function needs it
	import subprocess

	# get the device ID
	if getsignature() != 1:
		print("error: can't retrieve chip signature")
		return 0

	# NOTE(review): fixed, predictable path in a world-writable directory
	fn="/tmp/.goodfet.hex"
	print("fromweb: retrieving image for chip %.02x.%.02x.%.02x" % (signature[0], signature[1], signature[2]))
	url = "http://pa-ri.sc/goodfet/fw/%.02x%.02x%.02x.hex" % (signature[0], signature[1], signature[2])

	# -f makes curl exit non-zero on an HTTP error instead of saving the
	# server's error page; the old code flashed whatever ended up in the file
	try:
		rc = subprocess.call(["curl", "-f", "-o", fn, url])
	except OSError:
		# curl could not be started
		rc = -1

	if rc != 0:
		print("error: couldnt download image")
		return 0

	return flash(fn)

# main
#
# usage: <script> [port [baud]]
b = 500000
p = "/dev/ttyUSB0"
if len(sys.argv) > 1:
	p = sys.argv[1]
if len(sys.argv) > 2:
	# argv entries are strings; convert so a typo is reported up front
	try:
		b = int(sys.argv[2])
	except ValueError:
		print("error: invalid baud rate '%s'" % sys.argv[2])
		sys.exit(1)

if blsinit(p, b) != 1:
	print("error: couldnt initialize bls")
	sys.exit(1)

# commands that cannot run without an argument
needsarg = ("peek", "address", "flash")

# interactive command loop
while 1:
	sys.stdout.write("donbL> ")
	sys.stdout.flush()

	line = sys.stdin.readline()
	if line == "":
		# readline() returns '' only at end of input; treat it as "exit"
		# rather than printing the prompt forever
		print("")
		c = ["exit"]
	else:
		# split() collapses runs of blanks, so extra spaces don't produce
		# empty arguments
		c = line.split()

	# blank line: just prompt again
	if not c:
		continue

	cmd = c[0]

	if cmd in needsarg and len(c) < 2:
		print("donbL: %s requires an argument" % cmd)
		continue

	if cmd == "exit":
		print("donbL: exiting")
		doexit()
		s.close()
		break

	elif cmd == "peek":
		print("donbL: peeking address")
		peek(c[1])

	elif cmd == "address":
		print("donbL: setting flash address")
		addr(c[1])

	elif cmd == "pagesz":
		print("donbL: retrieving page size")
		pagesz()

	elif cmd == "signature":
		print("donbL: retrieving avr signature")
		getsignature()

	elif cmd == "fuse":
		print("donbL: retrieving avr fuse and lock bytes")
		fuse()

	elif cmd == "flash":
		print("donbL: flashing image")
		flash(c[1])

	elif cmd == "fromweb":
		print("donbL: flashing new image from web")
		fromweb()

	elif cmd == "reset":
		# reopen the port and redo the boot loader handshake
		s.close()
		if blsinit(p, b) != 1:
			print("error: couldnt initialize bls")

	else:
		print("donbL: unknown command '%s'" % cmd)

